import torch
import torchvision

vgg16 = torchvision.models.vgg16()

# 保存方式1 (模型结构+模型参数)
torch.save(vgg16, 'vgg16_method1.pth')

# 保存方式2 （模型参数) (体积更小, 推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')


# 导入模型
# 注: 若不是网络模型的话需要有原模型的class定义
vgg16_1 = torch.load('vgg16_method1.pth')

vgg16_2 = torchvision.models.vgg16()
vgg16_2.load_state_dict(torch.load('vgg16_method2.pth'))